""" Helper Modules """

from functools import partial
from typing import Any, Tuple, List, Dict, Union, Type, Optional, Callable

import jax
import numpy as np
import jax.numpy as jnp
from jax import nn
import haiku as hk
import einops

from diffgro.common.models.utils import (
    act2fn,
    init_he_uniform,
    init_he_normal,
)


# ============================== Time Embeddings ============================= #


class SinPosEmb(hk.Module):
    """Sinusoidal Positional Embedding

    Each input value is multiplied by ``half_dim = emb_dim // 2`` frequencies
    and passed through sine and cosine. For ``half_dim >= 2`` the frequencies
    are spaced geometrically from 1 down to 1/10000.

    Note:
        The output width is ``2 * (emb_dim // 2)``, so an odd ``emb_dim``
        produces ``emb_dim - 1`` features.
    """

    def __init__(self, emb_dim: int):
        """
        Args:
            emb_dim: requested embedding width.
        """
        super().__init__()
        self.emb_dim = emb_dim

    def __call__(self, x: jax.Array) -> jax.Array:
        """Embed `x`.

        Args:
            x: values to embed. They are multiplied with a ``[1, half_dim]``
                frequency row, so `x` must broadcast against that shape
                (presumably ``[batch, 1]`` -- confirm against callers).

        Returns:
            ``sin(x * freqs)`` and ``cos(x * freqs)`` concatenated along the
            last axis.
        """
        half_dim = self.emb_dim // 2
        # Clamp the denominator to 1: with half_dim == 1 a plain
        # `half_dim - 1` is zero, the step becomes infinite and
        # 0 * -inf evaluates to NaN.
        step = jnp.log(10000) / max(half_dim - 1, 1)
        freqs = jnp.exp(jnp.arange(0, half_dim) * -step)
        angles = x * freqs[None, :]
        return jnp.concatenate((jnp.sin(angles), jnp.cos(angles)), axis=-1)


class TimeEmb(hk.Module):
    """Time Embedding

    Sinusoidal features (``SinPosEmb``) followed by two linear layers with an
    activation in between: ``emb_dim * 2`` hidden units, then ``emb_dim``
    outputs.
    """

    def __init__(self, emb_dim: int, activation_fn: str = "mish"):
        """
        Args:
            emb_dim: width of the sinusoidal features and of the output.
            activation_fn: key into ``act2fn`` selecting the activation.
        """
        super().__init__()
        self.activation_fn = activation_fn
        self.emb_dim = emb_dim

    def __call__(self, x: jax.Array) -> jax.Array:
        """Return the learned embedding of `x`."""
        sinusoid = SinPosEmb(self.emb_dim)(x)

        widen = hk.Linear(self.emb_dim * 2, name="time_emb_0", **init_he_normal())
        hidden = act2fn[self.activation_fn](widen(sinusoid))

        project = hk.Linear(self.emb_dim, name="time_emb_1", **init_he_normal())
        return project(hidden)


# ============================== Conv Helpers ============================= #


class DownSample1D(hk.Module):
    """Strided 1D convolution (kernel 3, stride 2) on "NCW" input."""

    def __init__(self, dim: int):
        """
        Args:
            dim: number of output channels of the convolution.
        """
        super().__init__()
        self.dim = dim

    def __call__(self, x: jax.Array):
        """Apply the stride-2 convolution to `x`."""
        conv = hk.Conv1D(
            self.dim,
            kernel_shape=3,
            stride=2,
            data_format="NCW",
        )
        return conv(x)


class UpSample1D(hk.Module):
    """Strided 1D transposed convolution (kernel 4, stride 2) on "NCW" input."""

    def __init__(self, dim: int):
        """
        Args:
            dim: number of output channels of the transposed convolution.
        """
        super().__init__()
        self.dim = dim

    def __call__(self, x: jax.Array):
        """Apply the stride-2 transposed convolution to `x`."""
        deconv = hk.Conv1DTranspose(
            self.dim,
            kernel_shape=4,
            stride=2,
            data_format="NCW",
        )
        return deconv(x)


class Conv1DBlock(hk.Module):
    """Conv1D -> GroupNorm -> activation (Mish by default)"""

    def __init__(
        self,
        out_channels: int,
        kernel_size: int,
        n_groups: int = 8,
        activation_fn: str = "mish",
    ):
        """
        Args:
            out_channels: number of output channels of the convolution.
            kernel_size: kernel width of the convolution.
            n_groups: number of groups passed to ``hk.GroupNorm``.
            activation_fn: key into ``act2fn`` selecting the activation
                applied after normalisation.
        """
        super().__init__()
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_groups = n_groups
        self.activation_fn = activation_fn

    def __call__(self, x: jax.Array):
        """Apply convolution, group normalisation and activation to `x`.

        Args:
            x: input laid out as "NCW" (the layout passed to the convolution).

        Returns:
            The activated feature map, rearranged back to three axes.
        """
        x = hk.Conv1D(
            self.out_channels, kernel_shape=self.kernel_size, data_format="NCW"
        )(x)
        # GroupNorm runs on a 4D view with a singleton axis inserted before
        # the length axis; the axis is dropped again right afterwards.
        x = einops.rearrange(x, "b c h -> b c 1 h")
        x = hk.GroupNorm(self.n_groups, data_format="NCW")(x)
        x = einops.rearrange(x, "b c 1 h -> b c h")
        return act2fn[self.activation_fn](x)


class Residual(hk.Module):
    """Skip connection around `fn`: returns ``fn(x, ...) + x``."""

    def __init__(self, fn):
        """
        Args:
            fn: callable invoked as ``fn(x, *args, **kwargs)``.
        """
        super().__init__()
        self.fn = fn

    def __call__(self, x, *args, **kwargs) -> jax.Array:
        """Run `fn` on `x` (forwarding extra arguments) and add `x` back."""
        branch = self.fn(x, *args, **kwargs)
        return branch + x


class PreNorm(hk.Module):
    """Apply LayerNorm over axis 1, then call `fn` on the normalised input."""

    def __init__(self, dim, fn):
        """
        Args:
            dim: accepted but never stored or read by this class.
            fn: callable applied to the normalised input.
        """
        super().__init__()
        self.fn = fn

    def __call__(self, x) -> jax.Array:
        """Normalise `x` (learned scale and offset) and pass it to `fn`."""
        norm = hk.LayerNorm(axis=1, create_scale=True, create_offset=True)
        normalised = norm(x)
        return self.fn(normalised)


# ============================== Default MLP Layers ============================= #


class MLP(hk.Module):
    """MLP Layer

    Every selected entry of the input dictionary is embedded by its own
    two-layer network (``emb_dim * 2`` hidden units, ``emb_dim`` outputs).
    The embeddings are concatenated along the last axis and fed through the
    hidden layers given by `net_arch`, then an optional output layer and an
    optional tanh.
    """

    def __init__(
        self,
        emb_dim: int,
        out_dim: int,
        net_arch: List[int],
        batch_keys: List[str],
        activation_fn: str = "mish",
        squash_output: bool = False,
    ):
        """
        Args:
            emb_dim: output width of each per-key embedding.
            out_dim: width of the final linear layer; ``-1`` skips that layer.
            net_arch: widths of the hidden layers applied after concatenation.
            batch_keys: keys of the input dictionary to embed, in
                concatenation order. Each key is also part of the names of
                its embedding layers.
            activation_fn: key into ``act2fn`` selecting the activation.
            squash_output: if true, apply tanh to the output.
        """
        super().__init__()
        self.emb_dim = emb_dim
        self.out_dim = out_dim
        self.net_arch = net_arch
        self.batch_keys = batch_keys
        self.activation_fn = activation_fn
        self.squash_output = squash_output

    def __call__(self, batch_dict: Dict[str, jax.Array]) -> jax.Array:
        """Embed the selected inputs and run them through the MLP.

        Args:
            batch_dict: mapping that contains every key in `batch_keys`.

        Returns:
            The network output.

        Raises:
            ValueError: if `batch_keys` is empty.
            KeyError: if a key in `batch_keys` is missing from `batch_dict`.
        """
        # Fail with a clear message instead of the generic error that
        # concatenating an empty list would raise further down.
        if not self.batch_keys:
            raise ValueError("MLP requires at least one key in `batch_keys`.")

        activation = act2fn[self.activation_fn]

        embeddings = []
        for key in self.batch_keys:
            value = batch_dict[key]
            emb = hk.Linear(
                self.emb_dim * 2, name=f"mlp_{key}_emb_0", **init_he_normal()
            )(value)
            emb = activation(emb)
            emb = hk.Linear(
                self.emb_dim, name=f"mlp_{key}_emb_1", **init_he_normal()
            )(emb)
            embeddings.append(emb)

        # Last axis now has width len(batch_keys) * emb_dim.
        out = jnp.concatenate(embeddings, axis=-1)

        for ind, dim in enumerate(self.net_arch):
            out = hk.Linear(dim, name=f"mlp_{ind}", **init_he_normal())(out)
            out = activation(out)

        if self.out_dim != -1:
            out = hk.Linear(self.out_dim, name="mlp_last", **init_he_normal())(out)

        if self.squash_output:
            out = nn.tanh(out)
        return out


def _layer_norm(x: jax.Array) -> jax.Array:
    """Normalise `x` over its last axis with a newly created LayerNorm.

    Each call constructs its own ``hk.LayerNorm`` with a learned scale and
    offset.
    """
    return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)


class Transformer(hk.Module):
    """Pre-norm transformer stack with a lower-triangular attention mask.

    Each of the `num_layers` layers applies
    LayerNorm -> multi-head self-attention -> dropout -> residual add,
    followed by LayerNorm -> Linear / GELU / Linear -> dropout -> residual
    add. A final LayerNorm is applied to the result.
    """

    def __init__(
        self,
        num_heads: int,
        num_layers: int,
        dropout_rate: float,
        widening_factor: int = 4,
    ):
        """
        Args:
            num_heads: number of attention heads per layer.
            num_layers: number of attention + feed-forward layers. Also
                scales the weight initialiser (``2 / num_layers``).
            dropout_rate: dropout rate used while training.
            widening_factor: the feed-forward hidden width is
                ``widening_factor * emb_size``.
        """
        super().__init__()
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.dropout_rate = dropout_rate
        self.widening_factor = widening_factor

    def __call__(
        self,
        embeddings: jax.Array,
        is_training: bool = True,
    ) -> jax.Array:
        """Run the transformer stack over `embeddings`.

        Args:
            embeddings: rank-3 input; the second axis is used as the sequence
                length and the third as the embedding size (presumably
                ``[batch, seq_len, emb_size]``).
            is_training: Python bool. When false the dropout rate is set to
                zero so the forward pass is deterministic. Defaults to true,
                which keeps dropout applied on every call.

        Returns:
            Array with the same shape as `embeddings`.
        """
        initializer = hk.initializers.VarianceScaling(2 / self.num_layers)
        # A rate of zero turns dropout off. `hk.next_rng_key()` is still
        # called below, so the sequence of RNG keys drawn (and hence the
        # parameter initialisation) does not depend on `is_training`.
        dropout_rate = self.dropout_rate if is_training else 0.0
        _, seq_len, emb_size = embeddings.shape

        # Lower-triangular mask with two leading singleton axes
        # (presumably broadcast over batch and heads): position i may
        # attend to positions j <= i only.
        mask = np.tril(np.ones((1, 1, seq_len, seq_len)))
        h = embeddings
        for _ in range(self.num_layers):
            # Self-attention sub-layer.
            attn_block = hk.MultiHeadAttention(
                num_heads=self.num_heads,
                key_size=emb_size // self.num_heads,
                model_size=emb_size,
                w_init=initializer,
            )
            h_norm = _layer_norm(h)
            h_attn = attn_block(h_norm, h_norm, h_norm, mask=mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = h + h_attn

            # Feed-forward sub-layer.
            dense_block = hk.Sequential(
                [
                    hk.Linear(self.widening_factor * emb_size, w_init=initializer),
                    jax.nn.gelu,
                    hk.Linear(emb_size, w_init=initializer),
                ]
            )
            h_norm = _layer_norm(h)
            h_dense = dense_block(h_norm)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = h + h_dense

        return _layer_norm(h)
